from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_text_splitters import Language
from langchain_community.document_loaders.parsers import LanguageParser
from langchain_community.document_loaders.generic import GenericLoader
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import (
    HumanMessage, AIMessage, SystemMessage
)
import utils
import torch
from langchain_huggingface import HuggingFaceEmbeddings
import os
import re
import gradio as gr
import requests
import json
import subprocess
import threading
import time

# DEBUG is truthy only for the exact strings "true", "1" or "yes".
DEBUG = os.environ.get("DEBUG", "false") in ("true", "1", "yes")
# webui_command = ["python", "api_server.py"]
# webui_process = subprocess.Popen(
#     webui_command, text=True)

# Local path of the chat model that vLLM will serve.
model_name = "hf-models/glm-4-9b-chat"

# Command line for the vLLM OpenAI-compatible API server, launched below
# as a background child process and reached via `base_url`.
api_server_command = [
    "python",
    "-m",
    "vllm.entrypoints.openai.api_server",
    "--model",
    model_name,
    "--dtype",
    "float16",
    "--api-key",
    "",
    "--tensor-parallel-size",
    "4",
    "--trust-remote-code",
    "--gpu-memory-utilization",
    "0.8",
    "--disable-log-requests",
    "--disable-log-stats",
    "--port",
    "8000",
    # When running multiple models on multiple GPUs, vLLM can report
    # "GPU blocks: 0" -- see https://github.com/vllm-project/vllm/issues/2248
    # "--enforce-eager"
]
# Non-blocking: the server boots in the background; readiness is polled
# lazily inside chat_fn via utils.is_port_open.
api_process = subprocess.Popen(
    api_server_command, text=True)

# Flipped to True once the server's port answers.
api_server_ready = False

print("开始启动 api 服务")  # "starting the API service"

# Token used to clone the (private) documentation repos over HTTPS.
GITEE_ACCESS_TOKEN = os.environ.get("GITEE_ACCESS_TOKEN", "")

if (not GITEE_ACCESS_TOKEN):
    print("GITEE_ACCESS_TOKEN 环境变量不存在")  # env var is missing

# Base URL of the local vLLM OpenAI-compatible endpoint.
base_url = "http://127.0.0.1:8000/v1/"
# Directory the documentation repos are cloned into.
repo_path = "/repos"
# Shallow, single-branch clones keep the checkouts small and fast.
git_clone_ai_doc_command = ["git", "clone",
                            f"https://oauth2:{GITEE_ACCESS_TOKEN}@gitee.com/gitee-ai/docs.git",  "--depth=1", "--single-branch", repo_path+"/ai.gitee.com"]

git_clone_gitee_doc_command = ["git", "clone",
                               f"https://oauth2:{GITEE_ACCESS_TOKEN}@gitee.com/oschina/gitee-help-center.git", "--depth=1", "--single-branch", repo_path+"/help.gitee.com"]

def clone_doc_repo():
    """Shallow-clone both documentation repos into `repo_path`.

    Best-effort: failures are printed rather than raised so a broken clone
    does not kill app startup. `check=True` turns a non-zero git exit
    status into CalledProcessError; the original calls silently ignored
    clone failures (subprocess.run does not raise on its own).
    """
    try:
        subprocess.run(
            git_clone_ai_doc_command, text=True, check=True)
        subprocess.run(
            git_clone_gitee_doc_command, text=True, check=True)
    except Exception as e:
        print("克隆仓库发生错误:", e)  # "error while cloning repos"


# time.sleep(180)
# Vector store and retriever are module-level, rebuilt by update_doc_db().
# "" is a falsy placeholder meaning "not built yet" (checked as `if (not db)`).
db = ""
retriever = ""

# BGE-M3 embedding model, loaded locally via HuggingFace and pinned to GPU 1.
bge_model_name = "hf-models/bge-m3"
bge_model_kwargs = {'device': 'cuda:1'}  # 'torch_dtype': torch.float16
bge_encode_kwargs = {'normalize_embeddings': False}
hfEmbeddings = HuggingFaceEmbeddings(
    model_name=bge_model_name,
    model_kwargs=bge_model_kwargs,
    encode_kwargs=bge_encode_kwargs,
)

# When running multiple models on multiple GPUs, vLLM can report
# "GPU blocks: 0" -- see https://github.com/vllm-project/vllm/issues/2248
torch.cuda.empty_cache()
print('向量模型准备完成')  # "embedding model ready"


# Rewrite document links: when the content carries a `slug:` front-matter
# entry, that slug takes precedence over the path-derived URL.

def update_sources(documents, base_path="/repos"):
    """Turn each document's local `source` path into a public URL and
    prepend a reference-link line to its page content.

    For a doc whose metadata `source` lies under *base_path* (e.g.
    ``/repos/help.gitee.com/docs/xxx.md``), the path is rewritten to an
    ``https://`` URL, the ``.md`` suffix is stripped, and for
    help.gitee.com the ``/docs`` prefix is dropped (the site serves docs
    at its root). An explicit ``slug: /...`` in the content wins over the
    derived URL. The raw `source` key is removed from metadata so local
    paths do not leak into prompts.

    Args:
        documents: LangChain Document-like objects exposing
            ``.page_content`` (str) and ``.metadata`` (dict).
        base_path: local clone root; defaults to the module's "/repos".

    Returns:
        The same list, mutated in place.
    """
    prefix = base_path + '/'
    for doc in documents:
        source = doc.metadata.get('source')
        # e.g. /repos/ai.gitee.com/docs/xxx or /repos/help.gitee.com/docs/xxx
        if source and source.startswith(prefix):
            # Map the checkout path back onto the public https:// URL.
            source = source.replace(prefix, 'https://', 1)
            if source.endswith('.md'):
                source = source[:-3]
            # help.gitee.com publishes without the /docs path segment.
            if 'help.gitee.com/docs' in source:
                source = source.replace('/docs', '', 1)

        # A front-matter slug overrides the path-derived link.
        slug_match = re.search(r'slug: (/.+)', doc.page_content)
        if slug_match:
            slug = slug_match.group(1)
            doc.page_content = f'参考文档链接: https://help.gitee.com{slug}\n\n{doc.page_content}'
        elif source:
            doc.page_content = f'参考文档链接: {source}\n\n{doc.page_content}'

        # Drop the raw path from metadata (no-op when absent).
        doc.metadata.pop('source', None)

    return documents


def update_doc_db():
    """Re-clone the doc repos, split them into chunks, rebuild the vector DB.

    Side effects: replaces the module-level `db` (Chroma store) and
    `retriever`, and empties the CUDA cache afterwards.
    """
    global db
    global retriever
    # shutil.rmtree instead of shelling out to `rm -rf`: portable, no
    # child process; ignore_errors keeps the original best-effort
    # semantics when the path does not exist yet.
    import shutil
    shutil.rmtree(repo_path, ignore_errors=True)
    clone_doc_repo()
    loader = GenericLoader.from_filesystem(
        repo_path,
        glob="**/*",
        suffixes=[".md", ".txt"],
        exclude=["**/non-utf8-encoding.py", ".git/**"],
        # Without an explicit `language=` the parser infers it from the
        # file suffix.
        parser=LanguageParser(parser_threshold=0),
    )
    documents = loader.load()
    documents = update_sources(documents)
    # Renamed from `python_splitter`: the splitter is configured for
    # Markdown, not Python.
    markdown_splitter = RecursiveCharacterTextSplitter.from_language(
        # ValueError: Batch size 51030 exceeds maximum batch size 41666
        language=Language.MARKDOWN, chunk_size=1800, chunk_overlap=400
    )
    docs = markdown_splitter.split_documents(documents)
    print("召回内容", docs[5:10])
    print("docs 分块长度", len(docs))
    db = Chroma.from_documents(docs, hfEmbeddings)
    retriever = db.as_retriever(
        # alternatives: 'similarity_score_threshold', 'mmr'
        search_type="similarity",
        search_kwargs={"k": 13}  # number of chunks recalled per query
    )
    print('向量库已生成')  # "vector store built"
    # Free GPU memory held by the embedding pass
    # (https://github.com/vllm-project/vllm/issues/2248).
    torch.cuda.empty_cache()

# print(documents[100:500])


update_doc_db()
# chunk_size: characters per chunk, i.e. the length of each recalled item;
# code files need a fairly large value to keep file content intact
# (300 lines of code can be ~6000 characters).
# chunk_overlap: overlap between adjacent chunks, so content sitting on a
# chunk boundary is not lost after splitting.

# def read_process_output(process):
#     """Reads the process output and prints it."""
#     while True:
#         output = process.stdout.readline()
#         if output == '' and process.poll() is not None:
#             break
#         if output:
#             print("API Server: "+output.strip())


# prompt = ChatPromptTemplate.from_messages(
#     [
#         (
#             "system", "你是一个 Gitee AI 文档助手，参考上下文中的文档内容，回答用户问题。你的名字叫马建仓。你幽默风趣，遵守中国法律，不回复任何敏感、违法、违反道德的问题。:\n\n{context}",
#         ),
#         ("placeholder", "{chat_history}"),
#         ("user", "{input}"),
#     ]
# )


# retriever_chain = create_history_aware_retriever(llm, retriever, prompt)
# document_chain = create_stuff_documents_chain(llm, prompt)

# qa = create_retrieval_chain(retriever_chain, document_chain)

# question = """
# Grimoire Gitee AI 快速开始
# """
# result = qa.invoke({"input": question})
# print(result["answer"])


# for chunk in qa.stream({"input": question}):
#     if (chunk.get("context")):
#         print("召回内容", chunk.get("context"))
#     if answer_chunk := chunk.get("answer"):
#         print(f"{answer_chunk}", end="")


# system_prompt = {
#     "role": "system",
#     "content":
#     "你是一个 Gitee AI 文档助手，参考上下文中文档内容，回答用户问题。你的名字叫马建仓。你幽默风趣，遵守中国法律，不回复任何敏感、违法、违反道德的问题。"
# }

# Streaming chat client pointed at the local vLLM OpenAI-compatible server.
llm = ChatOpenAI(model=model_name, api_key="EMPTY", base_url=base_url,
                 # stop token ids: glm4 [151329, 151336, 151338], qwen2 [151643, 151644, 151645]
                 callbacks=[StreamingStdOutCallbackHandler()],
                 streaming=True, temperature=0.2, presence_penalty=1.2, top_p=0.95, extra_body={"stop_token_ids": [151329, 151336, 151338]})


def chat_fn(message, history):
    """Gradio ChatInterface handler: stream an answer grounded in retrieved docs.

    Args:
        message: the user's new question (str).
        history: list of (user, assistant) message pairs from Gradio.

    Yields:
        The progressively accumulated answer text (Gradio re-renders the
        full string on each yield).

    Raises:
        gr.Error: on empty input or input longer than 1000 characters.
    """
    gen_time_start = time.time()
    global api_server_ready
    global db
    global retriever

    if not api_server_ready:
        api_server_ready = utils.is_port_open(base_url)
        # gr.Info is a notification call, not an exception — the original
        # `raise gr.Info(...)` would itself fail. Show the toast, emit a
        # placeholder answer, and stop this turn.
        gr.Info("API 服务正在启动服务，请稍后...")
        yield "API 服务正在启动服务，请重试..."
        return
    if not message:
        raise gr.Error("请输入您的问题！")
    if len(message) > 1000:
        raise gr.Error("输入长度不能超过 1000 哦！")
    print("\ninput:", message)
    if not db:
        # Lazily build the vector store on first use.
        update_doc_db()
    history_openai_format = []
    search_res = str(retriever.invoke(message))
    if DEBUG:
        print("搜索结果:", search_res)
    history_openai_format.append(SystemMessage(
        content=f"""
- 【你十分幽默风趣活泼，名叫马建仓，是 Gitee 吉祥物，代表 Gitee 企业形象。】
- 你的回答参考 markdown 内容，紧扣用户问题，使用 markdown 格式回答。
- 回答时，确保
    文档中有 https 链接（图片、链接等），如有必要，请提供完整，可以使用 markdown 渲染，相对路径和不相关图片请勿提供。
    如有必要, 回答末尾可提供前少量参考文档 https:// 链接, 方便用户了解更多, 【注意: 请勿编造链接，请勿总结混合链接，没有就不要提供。】
    遵守中国法律，不回复任何敏感、违法、违反道德的问题。
    如果你不知道答案，就不要捏造假内容。
- 内容为: {search_res}
"""))
    if DEBUG:
        print("history", str(history))
    for human, ai in history:
        # Guard against non-string turns: None would fail pydantic
        # validation inside HumanMessage/AIMessage.
        human = human if isinstance(human, str) else ""
        ai = ai if isinstance(ai, str) else ""
        history_openai_format.append(HumanMessage(content=human))
        history_openai_format.append(AIMessage(content=ai))
    history_openai_format.append(HumanMessage(message))
    full_answer = ""
    first_chunk_logged = False
    for response in llm.stream(history_openai_format):
        full_answer += response.content
        if not first_chunk_logged and full_answer:
            # Original checked `len(full_answer) == 1`, which never fires
            # when the first streamed chunk is longer than one character.
            first_chunk_logged = True
            print("首字消耗时间:", time.time() - gen_time_start)
        yield full_answer


# Chat display area, shared with the ChatInterface constructed below.
chatbot = gr.Chatbot(height=550, label="Gitee 马建仓")


def toast_info_update_message():
    """Pop a toast confirming that a doc-db refresh has been requested."""
    gr.Info("已请求更新")


# Input textbox reused by the ChatInterface; scale=7 widens it relative
# to the submit button.
prompt = gr.Textbox(
    placeholder="输入你的问题...", container=False, scale=7)

# Assemble the Gradio UI and launch the app.
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as demo:
    # Fixed: the original markup was "<div><div>" / "...<div>" with
    # unclosed tags; the closing tags are now proper "</div>".
    gr.HTML("<div></div>")
    gr.HTML("<div>哈喽！我是马建仓，你在使用 Gitee AI 过程中有任何问题都可以找我！</div>")
    gr.Markdown("- 注: AI 输出，仅用于文档参考，不代表官方观点")
    chat = gr.ChatInterface(chat_fn,
                            submit_btn="提问",
                            chatbot=chatbot,
                            textbox=prompt,
                            clear_btn="清空对话",
                            stop_btn="暂停",
                            undo_btn=None,
                            retry_btn="重试提问",
                            examples=["Gitee AI 是什么?",
                                      "Gitee AI Serverless API 是什么, 在哪里体验购买",
                                      "为什么要购买企业版?",
                                      "如何购买企业版?",
                                      "应用能干嘛？",
                                      "如何创建应用？",
                                      "介绍模型引擎", "一个应用会有哪些状态，分别是什么意思？", "应用中如何使用天数智芯算力？", "应用的环境变量和秘钥有何区别, 如何在代码中使用"],
                            )
    # Maintenance button, hidden via CSS (#update_button { display: none });
    # it can still be triggered programmatically to rebuild the vector DB.
    update_button = gr.Button("点击更新文档向量库", elem_id="update_button")

    def on_button_click():
        """Rebuild the vector store, then confirm with a toast."""
        update_doc_db()
        toast_info_update_message()
    update_button.click(fn=on_button_click)
    # Inline CSS: hide the update button and constrain chat images.
    demo.css = """
    #update_button {
        display: none;
        width: 200px;
    }
    .message img {
        width: 70%;
        max-width: 100%;
    }
    """
demo.queue(default_concurrency_limit=4)
demo.launch(show_api=False)
